import pandas as pd
import numpy as np
import json
from sklearn.metrics import precision_recall_fscore_support, roc_curve, auc
from sklearn.metrics import hamming_loss, accuracy_score, multilabel_confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import time
import os
from pathlib import Path
import traceback

# Global matplotlib typography, applied once at import time so that every
# figure produced by this script shares the same font sizes.
plt.rcParams.update({
    'font.size': 14,         # base font size
    'axes.titlesize': 16,
    'axes.labelsize': 14,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'legend.fontsize': 12,
})

def get_model_colors():
    """Return the model display-name -> hex color mapping used by the plots.

    The keys must match the display names that ``main`` assigns through its
    ``model_name_mapping`` ('ConflLlama-2005', 'ConflLlama-2010',
    'ConflLlama-2014', 'ConflLlama-full'); otherwise ``colors.get(name, 'gray')``
    lookups fall back to gray and every curve gets the same color. The order
    of the temporal colors matches the palette used by the radar plot.
    The two legacy keys are kept for backward compatibility.
    """
    return {
        # Temporal ConflLlama variants (display names produced by main)
        'ConflLlama-2005': '#2ca02c',      # Green
        'ConflLlama-2010': '#d62728',      # Red
        'ConflLlama-2014': '#1f77b4',      # Blue
        'ConflLlama-full': '#ff7f0e',      # Orange
        # Legacy names
        'ConflLlama-Standard': '#2ca02c',  # Green
        'ConflLlama-Alt': '#d62728',       # Red
    }

def print_progress(message, indent=0):
    """Print *message* prefixed with the current local ``[HH:MM:SS]`` time.

    Each ``indent`` level inserts two spaces between the timestamp and the
    message, so nested steps line up visually.
    """
    stamp = time.strftime("%H:%M:%S", time.localtime())
    padding = "  " * indent
    print("[{}] {}{}".format(stamp, padding, message))

def sanitize_filename(name):
    """Return *name* with filename-unsafe characters replaced by underscores.

    Colons, dots, forward slashes, backslashes and spaces each become ``_``.
    """
    unsafe_to_underscore = str.maketrans({ch: '_' for ch in ':./\\ '})
    return name.translate(unsafe_to_underscore)

def _coerce_score(value):
    """Return *value* as a non-negative finite float, or ``None`` if it is not one.

    Accepts anything ``float()`` accepts (ints, floats, numeric strings,
    including exponent notation such as ``1e-05``). Booleans, negatives,
    NaN/inf and non-numeric values are rejected.
    """
    if value is None or isinstance(value, bool):
        return None
    try:
        score = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    if not np.isfinite(score) or score < 0:
        return None
    return score


def parse_prediction(pred_str, threshold=0.5):
    """Parse one raw prediction cell into ``(primary_label, label_list)``.

    Recognised string formats:

    * A dict of ``label -> score`` (single quotes are turned into double
      quotes before JSON decoding). The highest-scoring label is primary;
      every label whose score is ``>= threshold`` goes into the list (which
      may therefore be empty). Entries whose score is not a usable number
      are ignored rather than invalidating the whole prediction.
    * A ``"|"``-separated list of labels; "Unknown" and empty parts are
      dropped, and the first remaining label is primary.
    * A single bare label.

    Args:
        pred_str: Raw cell value. Non-string values (e.g. NaN) yield Unknown.
        threshold: Minimum score for a dict entry to enter the label list.

    Returns:
        ``(primary, labels)``; ``("Unknown", ["Unknown"])`` when nothing
        usable can be extracted.
    """
    try:
        if isinstance(pred_str, str):
            try:
                pred_dict = json.loads(pred_str.replace("'", '"'))
                if isinstance(pred_dict, dict):
                    valid_preds = []
                    for key, raw in pred_dict.items():
                        score = _coerce_score(raw)
                        if score is not None:
                            valid_preds.append((key, score))
                    if valid_preds:
                        valid_preds.sort(key=lambda item: item[1], reverse=True)
                        above = [key for key, score in valid_preds if score >= threshold]
                        return valid_preds[0][0], above
            except json.JSONDecodeError:
                # Not JSON: treat as "A | B" multi-label text or a single label.
                if "|" in pred_str:
                    labels = [label.strip() for label in pred_str.split("|")]
                    valid_labels = [label for label in labels if label and label != "Unknown"]
                    if valid_labels:
                        return valid_labels[0], valid_labels
                clean_label = pred_str.strip()
                if clean_label and clean_label != "Unknown":
                    return clean_label, [clean_label]
        return "Unknown", ["Unknown"]
    except Exception as e:
        print_progress(f"Error parsing prediction: {str(e)}", indent=2)
        return "Unknown", ["Unknown"]

def get_gold_standard_labels(row):
    """Return the non-missing, non-"Unknown" gold attack types of *row*.

    Reads ``attacktype1_txt``, ``attacktype2_txt`` and ``attacktype3_txt``
    in that order and keeps every value that is present and not "Unknown".
    """
    gold_columns = ('attacktype1_txt', 'attacktype2_txt', 'attacktype3_txt')
    return [
        row[column]
        for column in gold_columns
        if pd.notna(row[column]) and row[column] != "Unknown"
    ]

def calculate_single_label_metrics(y_true, y_pred, labels):
    """Compute per-label and macro-averaged metrics for single-label predictions.

    Args:
        y_true: Gold labels, one per sample.
        y_pred: Predicted labels, one per sample.
        labels: Label vocabulary to report on (predictions outside it only
            affect accuracy).

    Returns:
        dict mapping each label to ``{'Precision', 'Recall', 'F1-score'}``,
        plus ``'macro_avg'`` holding the unweighted means of those three and
        the overall ``'Accuracy'``.
    """
    prec, rec, f1s, _support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, zero_division=0
    )
    overall_accuracy = accuracy_score(y_true, y_pred)

    result = {
        label: {'Precision': p, 'Recall': r, 'F1-score': f}
        for label, p, r, f in zip(labels, prec, rec, f1s)
    }
    result['macro_avg'] = {
        'Precision': np.mean(prec),
        'Recall': np.mean(rec),
        'F1-score': np.mean(f1s),
        'Accuracy': overall_accuracy,
    }
    return result

def calculate_multi_label_metrics(true_labels_list, pred_labels_list, unique_labels):
    """Compute per-label and aggregate metrics for multi-label predictions.

    Args:
        true_labels_list: Iterable (list or Series) of gold label collections,
            one per sample.
        pred_labels_list: Sequence of predicted label collections, one per sample.
        unique_labels: Ordered label vocabulary; labels outside it are ignored.

    Returns:
        dict mapping each label to its Precision/Recall/F1-score (derived from
        the per-label confusion matrix, 0 when undefined), plus ``'overall'``
        with Hamming loss, subset (exact-match) accuracy, partial-match
        accuracy (share of samples with at least one correctly predicted
        label) and the mean number of true / predicted labels per sample.
    """
    # Column of each label in the indicator matrices. A dict makes each lookup
    # O(1) instead of list.index()'s linear scan; setdefault keeps the first
    # position, matching list.index() if the vocabulary had duplicates.
    label_index = {}
    for position, label in enumerate(unique_labels):
        label_index.setdefault(label, position)

    n_labels = len(unique_labels)
    y_true = np.zeros((len(true_labels_list), n_labels))
    y_pred = np.zeros((len(pred_labels_list), n_labels))

    for i, (true_set, pred_set) in enumerate(zip(true_labels_list, pred_labels_list)):
        for label in true_set:
            col = label_index.get(label)
            if col is not None:
                y_true[i, col] = 1
        for label in pred_set:
            col = label_index.get(label)
            if col is not None:
                y_pred[i, col] = 1

    metrics = {}
    hamming = hamming_loss(y_true, y_pred)
    confusion_matrices = multilabel_confusion_matrix(y_true, y_pred)

    for i, label in enumerate(unique_labels):
        tn, fp, fn, tp = confusion_matrices[i].ravel()
        metrics[label] = {
            'Precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
            'Recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
            'F1-score': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0
        }

    subset_accuracy = accuracy_score(y_true, y_pred)
    # A sample is a partial match when gold and prediction share any label.
    has_overlap = np.sum(y_true * y_pred, axis=1) > 0
    partial_accuracy = np.count_nonzero(has_overlap) / len(y_true)

    metrics['overall'] = {
        'Hamming_Loss': hamming,
        'Subset_Accuracy': subset_accuracy,
        'Partial_Match_Accuracy': partial_accuracy,
        'Label_Cardinality_True': np.mean(np.sum(y_true, axis=1)),
        'Label_Cardinality_Predicted': np.mean(np.sum(y_pred, axis=1))
    }

    return metrics

def get_probabilities_for_label(predictions, label):
    """Return the model's score for *label* from one prediction cell.

    * dict, or a string that decodes (after single->double quote replacement)
      to a JSON object: ``float(value)`` for *label*, 0 if absent.
    * other string: treated as ``"A | B"`` multi-label text or a single label,
      giving 1.0 if *label* is among them and 0.0 otherwise.
    * anything else (e.g. NaN, or JSON that is not an object): 0.0.

    Non-numeric or non-finite scores map to 0.0 so that downstream
    ``roc_curve`` calls never receive NaN. Only ordinary exceptions are
    suppressed; KeyboardInterrupt/SystemExit propagate (the previous bare
    ``except:`` swallowed them).
    """
    try:
        if isinstance(predictions, dict):
            score = float(predictions.get(label, 0))
        elif isinstance(predictions, str):
            try:
                pred_dict = json.loads(predictions.replace("'", '"'))
            except json.JSONDecodeError:
                if "|" in predictions:
                    labels = [part.strip() for part in predictions.split("|")]
                    return 1.0 if label in labels else 0.0
                return 1.0 if predictions.strip() == label else 0.0
            if not isinstance(pred_dict, dict):
                return 0.0
            score = float(pred_dict.get(label, 0))
        else:
            return 0.0
        return score if np.isfinite(score) else 0.0
    except Exception:
        # Best effort: unparsable scores (None, text, overflow, ...) count as 0.
        return 0.0

def calculate_roc_curves(df, model_cols, true_col, unique_labels):
    """Compute one-vs-rest ROC curves per label and a macro-averaged curve per model.

    For each model column and each label, the gold column *true_col* is
    binarised (``== label``) and scored with ``get_probabilities_for_label``.
    Labels whose binarised gold vector contains a single class are skipped:
    ROC is undefined for them (sklearn warns and yields NaN TPR/AUC), and a
    NaN AUC would make the model's mean AUC NaN.

    Returns:
        ``(roc_data, class_roc_data)``. ``roc_data[model_col]`` holds the
        macro curve: ``fpr`` on a 100-point grid, the mean of the per-label
        TPRs interpolated onto that grid, and the mean per-label ``auc``.
        ``class_roc_data[model_col][label]`` holds raw ``fpr``/``tpr``/``auc``.
        Models for which no label could be computed are omitted from both.
    """
    roc_data = {}
    class_roc_data = {}

    for model_col in model_cols:
        print_progress(f"Processing ROC curve for {model_col}...", indent=2)
        all_fpr = []
        all_tpr = []
        all_auc = []
        class_data = {}

        for label in unique_labels:
            y_true = (df[true_col] == label).astype(int)
            if y_true.nunique() < 2:
                print_progress(f"Skipping ROC for {label}: only one class present in gold labels", indent=3)
                continue

            # Bind label as a default so the lambda never depends on loop state.
            y_score = df[model_col].apply(lambda x, lbl=label: get_probabilities_for_label(x, lbl))

            try:
                fpr, tpr, _ = roc_curve(y_true, y_score)
                roc_auc = auc(fpr, tpr)
            except Exception as e:
                print_progress(f"Error calculating ROC for {label}: {str(e)}", indent=3)
                continue

            all_fpr.append(fpr)
            all_tpr.append(tpr)
            all_auc.append(roc_auc)
            class_data[label] = {
                'fpr': fpr,
                'tpr': tpr,
                'auc': roc_auc
            }

        if all_auc:
            mean_auc = np.mean(all_auc)
            base_fpr = np.linspace(0, 1, 100)
            # roc_curve returns non-decreasing fpr, as np.interp requires.
            mean_tpr = np.mean([np.interp(base_fpr, fpr, tpr)
                                for fpr, tpr in zip(all_fpr, all_tpr)], axis=0)

            roc_data[model_col] = {
                'fpr': base_fpr,
                'tpr': mean_tpr,
                'auc': mean_auc
            }
            class_roc_data[model_col] = class_data

    return roc_data, class_roc_data

def plot_roc_curves(roc_data, model_name_mapping, colors, output_dir):
    """Plot every model's macro-averaged ROC curve in one figure.

    Args:
        roc_data: ``{model_col: {'fpr', 'tpr', 'auc'}}``.
        model_name_mapping: Maps model columns to display names. Keys may be
            the full column name (as ``main`` builds it) or the column name
            without its ``_predictions`` suffix; both are tried.
        colors: Display name -> color; unknown names are drawn in gray.
        output_dir: Directory receiving ``overall_roc_curves.png``.
    """
    print_progress("Creating overall ROC curve plot...", indent=1)
    fig = plt.figure(figsize=(12, 10))
    try:
        plt.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.8, label='Random', linewidth=2)

        for model_col, data in roc_data.items():
            base_model_name = model_col.replace('_predictions', '')
            display_name = model_name_mapping.get(
                model_col, model_name_mapping.get(base_model_name, base_model_name)
            )
            color = colors.get(display_name, 'gray')

            plt.plot(data['fpr'], data['tpr'],
                     label=f"{display_name} (AUC = {data['auc']:.3f})",
                     color=color,
                     linewidth=3)

        plt.xlabel('False Positive Rate', fontsize=14, labelpad=10)
        plt.ylabel('True Positive Rate', fontsize=14, labelpad=10)
        plt.legend(loc='lower right', fontsize=12, framealpha=0.9)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tick_params(axis='both', which='major', labelsize=12)
        plt.tight_layout()

        save_path = os.path.join(output_dir, 'overall_roc_curves.png')
        print_progress(f"Saving overall ROC curve plot to {save_path}", indent=2)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def plot_class_roc_curves(class_roc_data, model_name_mapping, colors, unique_labels, output_dir):
    """Save one ROC figure per label comparing all models.

    Figures go to ``<output_dir>/class_roc_curves/roc_curves_<label>.png``
    (label passed through ``sanitize_filename``). Models lacking a curve for
    a label are left out of that label's figure. Display names are looked up
    in *model_name_mapping* by full column name first, then by the name
    without its ``_predictions`` suffix.
    """
    print_progress("Creating class-wise ROC curve plots...", indent=1)

    class_roc_dir = os.path.join(output_dir, 'class_roc_curves')
    os.makedirs(class_roc_dir, exist_ok=True)

    for label in unique_labels:
        fig = plt.figure(figsize=(10, 8))
        try:
            plt.plot([0, 1], [0, 1], linestyle='--', color='black', alpha=0.8, label='Random')

            for model_col, class_data in class_roc_data.items():
                if label not in class_data:
                    continue
                base_model_name = model_col.replace('_predictions', '')
                display_name = model_name_mapping.get(
                    model_col, model_name_mapping.get(base_model_name, base_model_name)
                )
                color = colors.get(display_name, 'gray')

                data = class_data[label]
                plt.plot(data['fpr'], data['tpr'],
                         label=f"{display_name} (AUC = {data['auc']:.3f})",
                         color=color,
                         linewidth=2)

            plt.xlabel('False Positive Rate')
            plt.ylabel('True Positive Rate')
            plt.title(f'ROC Curves for {label}')
            plt.legend(loc='lower right')
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.tight_layout()

            # sanitize_filename removes path separators, so the file always
            # lands directly in class_roc_dir (created above).
            save_path = os.path.join(class_roc_dir, f'roc_curves_{sanitize_filename(label)}.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def plot_confusion_matrix(y_true, y_pred, labels, title, filename):
    """Save a row-normalised confusion-matrix heatmap.

    Each row is normalised over *all* predictions for that gold label, then
    the matrix is restricted to *labels* (rows/columns missing from the data
    are filled with 0). Predictions outside *labels* (e.g. "Unknown") are
    therefore dropped after normalisation, so rows may sum to less than 1.

    Args:
        y_true: Gold labels.
        y_pred: Predicted labels, aligned with *y_true*.
        labels: Row/column order of the matrix.
        title: Figure title.
        filename: Output image path.
    """
    cm = pd.crosstab(y_true, y_pred, normalize='index')
    # reindex both adds missing labels (filled with 0) and fixes the order.
    cm = cm.reindex(index=labels, columns=labels, fill_value=0)

    fig = plt.figure(figsize=(18, 14))
    try:
        heatmap = sns.heatmap(cm, annot=True, fmt='.3f', cmap='YlOrRd', square=True,
                              cbar_kws={'label': 'Normalized Frequency'},
                              annot_kws={'size': 10})

        # Adjust colorbar label size after plotting
        heatmap.collections[0].colorbar.ax.set_ylabel('Normalized Frequency', size=14)

        plt.title(title, pad=20)
        plt.ylabel('True Label', size=14, labelpad=10)
        plt.xlabel('Predicted Label', size=14, labelpad=10)
        plt.xticks(rotation=45, ha='right', size=12)
        plt.yticks(rotation=0, size=12)

        plt.tight_layout()
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_temporal_metrics(temporal_df, output_dir):
    """Plot grouped metric trajectories across model versions.

    Writes ``multilabel_progression.png`` (Hamming loss, subset and partial
    match accuracy) and ``macro_progression.png`` (macro precision, recall,
    F1 and accuracy) to *output_dir*. *temporal_df* must contain those
    columns and be indexed by model version.

    NOTE(review): ``plot_temporal_metrics`` is defined twice in this module.
    The later definition is the one bound at import time, so this copy is
    dead code; delete one of them.
    """
    metric_groups = {
        'multilabel': ['Hamming_Loss', 'Subset_Accuracy', 'Partial_Match_Accuracy'],
        'macro': ['Macro_Precision', 'Macro_Recall', 'Macro_F1-score', 'Macro_Accuracy'],
    }

    for group_name, metrics in metric_groups.items():
        fig = plt.figure(figsize=(14, 8))
        try:
            for metric in metrics:
                plt.plot(temporal_df.index, temporal_df[metric],
                         marker='o', label=metric.replace('_', ' '),
                         linewidth=3, markersize=8)

            plt.xlabel('Model Version', fontsize=14, labelpad=10)
            plt.ylabel('Score', fontsize=14, labelpad=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend(fontsize=12, framealpha=0.9)
            plt.xticks(rotation=45, size=12)
            plt.yticks(size=12)
            plt.tight_layout()

            plt.savefig(os.path.join(output_dir, f'{group_name}_progression.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def plot_temporal_progression(temporal_df, output_dir):
    """Plot every column of *temporal_df* against its index in one figure.

    The index is used as the x axis (model versions). Column names become
    legend entries with underscores replaced by spaces. Writes
    ``temporal_progression.png`` to *output_dir*.
    """
    fig = plt.figure(figsize=(14, 8))
    try:
        for column in temporal_df.columns:
            plt.plot(temporal_df.index, temporal_df[column],
                     marker='o', label=column.replace('_', ' '),
                     linewidth=3, markersize=8)

        plt.xlabel('Model Version', fontsize=14, labelpad=10)
        plt.ylabel('Score', fontsize=14, labelpad=10)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend(fontsize=12, framealpha=0.9)
        plt.xticks(rotation=45, size=12)
        plt.yticks(size=12)
        plt.tight_layout()

        plt.savefig(os.path.join(output_dir, 'temporal_progression.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def create_comparative_tables(single_label_results, multi_label_results, output_dir):
    """Write per-class and model-vs-baseline comparison tables as CSV.

    Files written to *output_dir*:

    * ``class_performance_comparison.csv``: one row per label (labels taken
      from the first model's results, ``'macro_avg'`` skipped), with
      ``<model>_Precision``, ``<model>_Recall`` and ``<model>_F1-score`` columns.
    * ``model_improvements.csv``: for every model after the first, the
      relative change in percent of each multi-label metric versus the first
      (baseline) model. This is the raw relative change, so for Hamming loss
      a positive value means the model got worse. A baseline value of 0 makes
      the relative change undefined, and NaN is written.

    Raises:
        ValueError: if *single_label_results* is empty.
    """
    models = list(single_label_results.keys())
    if not models:
        raise ValueError("single_label_results is empty; nothing to compare")

    # Per-class performance comparison
    class_rows = []
    for label in single_label_results[models[0]]:
        if label == 'macro_avg':
            continue
        row = {'Label': label}
        for model in models:
            for metric in ('Precision', 'Recall', 'F1-score'):
                row[f'{model}_{metric}'] = single_label_results[model][label][metric]
        class_rows.append(row)

    pd.DataFrame(class_rows).to_csv(
        os.path.join(output_dir, 'class_performance_comparison.csv'), index=False
    )

    # Model-to-model improvement table (relative to the first model)
    baseline = models[0]
    improvement_metrics = ('Hamming_Loss', 'Subset_Accuracy', 'Partial_Match_Accuracy')
    improvement_rows = []
    for model in models[1:]:
        row = {'Model': model}
        for metric in improvement_metrics:
            baseline_value = multi_label_results[baseline]['overall'][metric]
            current_value = multi_label_results[model]['overall'][metric]
            if baseline_value == 0:
                improvement = np.nan
            else:
                improvement = ((current_value - baseline_value) / baseline_value) * 100
            row[f'{metric}_improvement'] = improvement
        improvement_rows.append(row)

    pd.DataFrame(improvement_rows).to_csv(
        os.path.join(output_dir, 'model_improvements.csv'), index=False
    )
    
def calculate_temporal_metrics(single_label_results, multi_label_results, models=None):
    """Assemble a model-version x metric DataFrame for the temporal plots.

    Args:
        single_label_results: ``{model: {'macro_avg': {...}, ...}}``.
        multi_label_results: ``{model: {'overall': {...}, ...}}``.
        models: Row order. Defaults to the four ConflLlama temporal variants.

    Returns:
        DataFrame indexed by *models* with columns ``Hamming_Loss``,
        ``Subset_Accuracy``, ``Partial_Match_Accuracy`` and
        ``Macro_Precision``/``Macro_Recall``/``Macro_F1-score``/``Macro_Accuracy``.
        Models (or metrics) absent from the results become NaN. Previously they
        were dropped, which shortened that column and misaligned or broke
        DataFrame construction.
    """
    if models is None:
        models = ['ConflLlama-2005', 'ConflLlama-2010', 'ConflLlama-2014', 'ConflLlama-full']
    models = list(models)

    def _lookup(results, model, section, metric):
        # NaN placeholder keeps every column aligned with the model index.
        return results.get(model, {}).get(section, {}).get(metric, np.nan)

    temporal_data = {}

    # Multi-label metrics
    for metric in ('Hamming_Loss', 'Subset_Accuracy', 'Partial_Match_Accuracy'):
        temporal_data[metric] = [_lookup(multi_label_results, m, 'overall', metric) for m in models]

    # Macro metrics
    for metric in ('Precision', 'Recall', 'F1-score', 'Accuracy'):
        temporal_data[f'Macro_{metric}'] = [
            _lookup(single_label_results, m, 'macro_avg', metric) for m in models
        ]

    return pd.DataFrame(temporal_data, index=models)

def plot_temporal_metrics(temporal_df, output_dir):
    """Plot grouped metric trajectories across model versions.

    Writes ``multilabel_progression.png`` (Hamming loss, subset and partial
    match accuracy) and ``macro_progression.png`` (macro precision, recall,
    F1 and accuracy) to *output_dir*. *temporal_df* must contain those
    columns and be indexed by model version.

    NOTE(review): this is the second definition of ``plot_temporal_metrics``
    in the module and the one bound at import time. Its styling (sizes, line
    widths, markers) matches the other plots in this file.
    """
    metric_groups = {
        'multilabel': ['Hamming_Loss', 'Subset_Accuracy', 'Partial_Match_Accuracy'],
        'macro': ['Macro_Precision', 'Macro_Recall', 'Macro_F1-score', 'Macro_Accuracy'],
    }

    for group_name, metrics in metric_groups.items():
        fig = plt.figure(figsize=(14, 8))
        try:
            for metric in metrics:
                plt.plot(temporal_df.index, temporal_df[metric],
                         marker='o', label=metric.replace('_', ' '),
                         linewidth=3, markersize=8)

            plt.xlabel('Model Version', fontsize=14, labelpad=10)
            plt.ylabel('Score', fontsize=14, labelpad=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend(fontsize=12, framealpha=0.9)
            plt.xticks(rotation=45, size=12)
            plt.yticks(size=12)
            plt.tight_layout()

            plt.savefig(os.path.join(output_dir, f'{group_name}_progression.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def analyze_per_class_temporal(single_label_results, unique_labels, output_dir, models=None):
    """Save one precision/recall/F1 line chart per label across model versions.

    Writes ``temporal_progression_<label>.png`` (label passed through
    ``sanitize_filename``) to *output_dir*.

    Args:
        single_label_results: ``{model: {label: {'Precision', 'Recall', 'F1-score'}}}``.
        unique_labels: Labels to plot, one figure each.
        output_dir: Destination directory.
        models: x-axis order. Defaults to the four ConflLlama temporal variants.
            A model or label missing from the results is plotted as a gap
            (NaN) instead of raising ``KeyError``.
    """
    if models is None:
        models = ['ConflLlama-2005', 'ConflLlama-2010', 'ConflLlama-2014', 'ConflLlama-full']
    metrics = ['Precision', 'Recall', 'F1-score']

    for label in unique_labels:
        fig = plt.figure(figsize=(14, 8))
        try:
            for metric in metrics:
                values = [
                    single_label_results.get(model, {}).get(label, {}).get(metric, np.nan)
                    for model in models
                ]
                plt.plot(models, values, marker='o', label=metric,
                         linewidth=3, markersize=8)

            plt.xlabel('Model Version', fontsize=14, labelpad=10)
            plt.ylabel('Score', fontsize=14, labelpad=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.legend(fontsize=12, framealpha=0.9)
            plt.xticks(rotation=45, size=12)
            plt.yticks(size=12)
            plt.tight_layout()

            plt.savefig(os.path.join(output_dir, f'temporal_progression_{sanitize_filename(label)}.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def plot_temporal_heatmap(single_label_results, unique_labels, output_dir, models=None):
    """Save a label x model-version heatmap of per-class F1-scores.

    Writes ``temporal_heatmap_F1-score.png`` to *output_dir*. Rows follow
    *unique_labels*; columns follow *models* (default: the four ConflLlama
    temporal variants). Missing model/label entries are drawn as blank (NaN)
    cells instead of raising ``KeyError``. Nothing is drawn when
    *unique_labels* is empty.
    """
    if models is None:
        models = ['ConflLlama-2005', 'ConflLlama-2010', 'ConflLlama-2014', 'ConflLlama-full']
    if len(unique_labels) == 0:
        print_progress("No labels available; skipping temporal heatmap", indent=1)
        return

    for metric in ['F1-score']:
        data = [
            [single_label_results.get(model, {}).get(label, {}).get(metric, np.nan)
             for model in models]
            for label in unique_labels
        ]

        # Height scales with the label count; the floor keeps short lists legible.
        fig = plt.figure(figsize=(12, max(4, len(unique_labels) * 0.6)))
        try:
            heatmap = sns.heatmap(data, annot=True, fmt='.3f', cmap='YlOrRd',
                                  xticklabels=models, yticklabels=unique_labels,
                                  annot_kws={'size': 10})

            # Adjust colorbar label size after plotting
            heatmap.collections[0].colorbar.ax.set_ylabel(metric, size=14)

            plt.xlabel('Model Version', fontsize=14, labelpad=10)
            plt.ylabel('Attack Type', fontsize=14, labelpad=10)
            plt.xticks(rotation=45, ha='right', size=12)
            plt.yticks(rotation=0, size=12)
            plt.tight_layout()

            plt.savefig(os.path.join(output_dir, f'temporal_heatmap_{metric}.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

def plot_macro_metrics_progression(temporal_df, output_dir):
    """Draw the four macro metrics as lines across model versions.

    Each metric keeps a fixed color. The y axis is padded by 0.05 beyond the
    smallest and largest plotted values. The figure is saved as
    ``macro_metrics_progression.png`` in *output_dir*.
    """
    metric_palette = {
        'Macro_Precision': '#2ca02c',
        'Macro_Recall': '#d62728',
        'Macro_F1-score': '#1f77b4',
        'Macro_Accuracy': '#ff7f0e',
    }
    metric_names = list(metric_palette)
    padding = 0.05

    plt.figure(figsize=(15, 10))

    for name, colour in metric_palette.items():
        plt.plot(
            temporal_df.index,
            temporal_df[name],
            marker='o',
            label=name.replace('Macro_', 'Macro '),
            color=colour,
            linewidth=3,
            markersize=8,
        )

    plt.xlabel('Model Version', fontsize=14, labelpad=10)
    plt.ylabel('Score', fontsize=14, labelpad=10)

    # Dashed grid drawn beneath the data lines
    plt.grid(True, linestyle='--', alpha=0.7)
    current_axes = plt.gca()
    current_axes.set_axisbelow(True)

    plt.legend(fontsize=12, framealpha=0.9)
    plt.xticks(rotation=45, ha='right', size=12)
    plt.yticks(size=12)

    plotted = temporal_df[metric_names]
    lower = plotted.min().min() - padding
    upper = plotted.max().max() + padding
    plt.ylim(lower, upper)

    plt.tight_layout()

    destination = os.path.join(output_dir, 'macro_metrics_progression.png')
    plt.savefig(destination, dpi=300, bbox_inches='tight')
    plt.close()

def plot_temporal_f1_heatmap(single_label_results, unique_labels, output_dir, models=None):
    """Save a label x model-version heatmap of per-class F1-scores.

    Writes ``temporal_f1_heatmap.png`` to *output_dir*. Rows follow
    *unique_labels*; columns follow *models* (default: the four ConflLlama
    temporal variants). Missing model/label entries are drawn as blank (NaN)
    cells instead of raising ``KeyError``. Nothing is drawn when
    *unique_labels* is empty.
    """
    if models is None:
        models = ['ConflLlama-2005', 'ConflLlama-2010', 'ConflLlama-2014', 'ConflLlama-full']
    if len(unique_labels) == 0:
        print_progress("No labels available; skipping F1 heatmap", indent=1)
        return

    data = [
        [single_label_results.get(model, {}).get(label, {}).get('F1-score', np.nan)
         for model in models]
        for label in unique_labels
    ]

    # Height scales with the label count; the floor keeps short lists legible.
    fig = plt.figure(figsize=(12, max(4, len(unique_labels) * 0.5)))
    try:
        heatmap = sns.heatmap(data,
                              annot=True,
                              fmt='.3f',
                              cmap='YlOrRd',  # Matches the other heatmaps
                              xticklabels=models,
                              yticklabels=unique_labels,
                              annot_kws={'size': 10},
                              cbar_kws={'label': 'F1-score'})

        colorbar_ax = heatmap.collections[0].colorbar.ax
        colorbar_ax.set_ylabel('F1-score', size=12)
        colorbar_ax.tick_params(labelsize=12)

        plt.xlabel('Model Version', fontsize=14, labelpad=10)
        plt.ylabel('Attack Type', fontsize=14, labelpad=10)
        plt.xticks(rotation=45, ha='right', size=12)
        plt.yticks(rotation=0, size=12)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'temporal_f1_heatmap.png'), dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_macro_metrics_radar(temporal_df, output_dir, ylim=None):
    """Draw a radar (polar) chart of the four macro metrics, one polygon per model.

    Args:
        temporal_df: DataFrame indexed by model with the ``Macro_*`` columns.
        output_dir: Directory receiving ``macro_metrics_radar.png``.
        ylim: Optional ``(rmin, rmax)`` radial range. By default the range is
            derived from the data (0.05 padding, clamped to ``[0, 1]``). A
            fixed window such as ``(0.4, 0.8)`` silently clips any score
            outside it; pass it explicitly to reproduce that look.

    Colors cycle through a four-color palette, so more than four models no
    longer raise ``IndexError``.
    """
    metrics = ['Macro_Precision', 'Macro_Recall', 'Macro_F1-score', 'Macro_Accuracy']
    models = temporal_df.index
    palette = ['#2ca02c', '#d62728', '#1f77b4', '#ff7f0e']

    # One spoke per metric; repeat the first angle to close the polygon.
    angles = np.linspace(0, 2 * np.pi, len(metrics), endpoint=False)
    angles = np.concatenate((angles, [angles[0]]))

    if ylim is None:
        all_values = temporal_df[metrics].to_numpy(dtype=float)
        finite = all_values[np.isfinite(all_values)]
        if finite.size:
            ylim = (max(0.0, finite.min() - 0.05), min(1.0, finite.max() + 0.05))
        else:
            ylim = (0.0, 1.0)

    fig = plt.figure(figsize=(15, 10))
    try:
        ax = fig.add_subplot(111, polar=True)

        for idx, model in enumerate(models):
            values = temporal_df.loc[model, metrics].to_numpy(dtype=float)
            values = np.concatenate((values, [values[0]]))  # close the polygon
            color = palette[idx % len(palette)]

            ax.plot(angles, values, 'o-', linewidth=3, label=model, color=color)
            ax.fill(angles, values, alpha=0.1, color=color)

        ax.set_xticks(angles[:-1])
        ax.set_xticklabels([m.replace('Macro_', '') for m in metrics], size=14)

        ax.set_ylim(*ylim)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(True, alpha=0.3)

        plt.legend(loc='center left', bbox_to_anchor=(1.2, 0.5), fontsize=12)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'macro_metrics_radar.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_macro_metrics_faceted(temporal_df, output_dir):
    """Draw a 2x2 grid with one annotated trend line per macro metric.

    Each subplot shows one ``Macro_*`` column of *temporal_df* across its
    index (model versions), labels every finite point with its value, and
    pads the y axis by 0.05 around the finite values. Writes
    ``macro_metrics_faceted.png`` to *output_dir*.
    """
    metrics = ['Macro_Precision', 'Macro_Recall', 'Macro_F1-score', 'Macro_Accuracy']
    colors = ['#2ca02c', '#d62728', '#1f77b4', '#ff7f0e']

    model_names = list(temporal_df.index)
    # Plot against explicit integer positions so tick locations and labels
    # can be set as a matched pair (set_xticks + set_xticklabels).
    positions = np.arange(len(model_names))

    fig, axes = plt.subplots(2, 2, figsize=(20, 16))
    try:
        for ax, metric, color in zip(axes.ravel(), metrics, colors):
            values = temporal_df[metric].to_numpy(dtype=float)

            ax.plot(positions, values,
                    marker='o', linewidth=3, markersize=10,
                    color=color)

            for x, y in zip(positions, values):
                if np.isfinite(y):
                    ax.annotate(f'{y:.3f}',
                                (x, y),
                                textcoords="offset points",
                                xytext=(0, 10),
                                ha='center',
                                fontsize=12)

            ax.set_title(metric.replace('Macro_', ''), pad=20, fontsize=16)
            ax.grid(True, linestyle='--', alpha=0.7)
            ax.set_xlabel('Model Version', fontsize=14)
            ax.set_ylabel('Score', fontsize=14)
            ax.tick_params(axis='both', labelsize=12)

            finite = values[np.isfinite(values)]
            if finite.size:  # set_ylim rejects NaN limits
                ax.set_ylim(finite.min() - 0.05, finite.max() + 0.05)

            ax.set_xticks(positions)
            ax.set_xticklabels(model_names, rotation=45, ha='right')

        plt.suptitle('Macro Metrics Progression Analysis',
                     fontsize=20, y=1.02)
        plt.tight_layout()

        plt.savefig(os.path.join(output_dir, 'macro_metrics_faceted.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

def plot_macro_metrics_stacked(temporal_df, output_dir):
    """Draw the four macro metrics as a stacked area chart across model versions.

    With ``stackplot``'s default zero baseline, band *i* spans from the
    cumulative sum of metrics ``0..i-1`` to that of ``0..i``. The outline
    lines and value labels are therefore drawn at each band's cumulative top.
    Each label still shows the metric's own (un-stacked) value. Drawing them
    at the raw values, as before, left them floating away from their bands.
    Writes ``macro_metrics_stacked.png`` to *output_dir*.
    """
    metrics = ['Macro_Precision', 'Macro_Recall', 'Macro_F1-score', 'Macro_Accuracy']
    colors = ['#2ca02c', '#d62728', '#1f77b4', '#ff7f0e']

    x = np.arange(len(temporal_df))
    series = np.vstack([temporal_df[m].to_numpy(dtype=float) for m in metrics])
    band_tops = np.cumsum(series, axis=0)

    fig = plt.figure(figsize=(15, 10))
    try:
        plt.stackplot(x,
                      series,
                      labels=[m.replace('Macro_', '') for m in metrics],
                      colors=colors,
                      alpha=0.6)

        for idx, color in enumerate(colors):
            plt.plot(x, band_tops[idx], color=color, linewidth=2, alpha=0.8)

            for xi, top, raw in zip(x, band_tops[idx], series[idx]):
                if np.isfinite(top):
                    plt.annotate(f'{raw:.3f}',
                                 (xi, top),
                                 textcoords="offset points",
                                 xytext=(0, 5),
                                 ha='center',
                                 fontsize=12)

        plt.xticks(x, temporal_df.index, rotation=45, ha='right', fontsize=12)
        plt.yticks(fontsize=12)

        plt.xlabel('Model Version', fontsize=14, labelpad=10)
        plt.ylabel('Cumulative Score', fontsize=14, labelpad=10)

        plt.grid(True, linestyle='--', alpha=0.3)
        plt.legend(fontsize=12, loc='center left',
                   bbox_to_anchor=(1, 0.5))

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, 'macro_metrics_stacked.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    
def main():
    """Run the end-to-end evaluation of the ConflLlama prediction columns.

    Reads the predictions CSV next to this script, then:

    1. computes and plots ROC curves for every model column;
    2. computes single-label and multi-label metrics and a confusion matrix
       per model (a failing model is logged and skipped);
    3. if any model succeeded, writes temporal plots, comparison tables and a
       multi-label summary (CSV and HTML).

    All outputs go to a ``results`` directory next to the script. Any other
    error is logged with its traceback and re-raised.
    """
    try:
        start_time = time.time()
        print_progress("Starting analysis...")

        script_dir = os.path.dirname(os.path.abspath(__file__))
        output_dir = os.path.join(script_dir, "results")
        os.makedirs(output_dir, exist_ok=True)

        print_progress("Reading data file...")
        results_file = os.path.join(script_dir, "final_results_hf.co_shreyasmeher_ConflLlama_Q8_0_cutoff_20170101.csv")
        df = pd.read_csv(results_file, low_memory=False)

        # Prediction columns of the four temporal ConflLlama variants
        model_cols = [
            'hf.co/shreyasmeher/ConflLlama-early-2005:latest_predictions',
            'hf.co/shreyasmeher/ConflLlama-mid-2010:latest_predictions',
            'hf.co/shreyasmeher/ConflLlama-late-2014:latest_predictions',
            'hf.co/shreyasmeher/ConflLlama:Q8_0_predictions'
        ]

        # Prediction column -> display name used in plots and tables
        model_name_mapping = {
            'hf.co/shreyasmeher/ConflLlama-early-2005:latest_predictions': 'ConflLlama-2005',
            'hf.co/shreyasmeher/ConflLlama-mid-2010:latest_predictions': 'ConflLlama-2010',
            'hf.co/shreyasmeher/ConflLlama-late-2014:latest_predictions': 'ConflLlama-2014',
            'hf.co/shreyasmeher/ConflLlama:Q8_0_predictions': 'ConflLlama-full'
        }

        print_progress("Extracting unique labels...")
        unique_labels = sorted(set(
            label for col in ['attacktype1_txt', 'attacktype2_txt', 'attacktype3_txt']
            for label in df[col].dropna().unique()
            if label != "Unknown"
        ))

        # Gold labels are identical for every model, so run the row-wise apply once.
        print_progress("Extracting gold standard labels...")
        true_primary = df['attacktype1_txt']
        true_multi = df.apply(get_gold_standard_labels, axis=1)

        single_label_results = {}
        multi_label_results = {}

        # Get color scheme
        colors = get_model_colors()

        print_progress("\nCalculating ROC curves for all models...")
        roc_data, class_roc_data = calculate_roc_curves(df, model_cols, 'attacktype1_txt', unique_labels)

        if roc_data:
            print_progress("Plotting ROC curves...")
            try:
                plot_roc_curves(roc_data, model_name_mapping, colors, output_dir)
                plot_class_roc_curves(class_roc_data, model_name_mapping, colors, unique_labels, output_dir)
                print_progress("ROC curves saved successfully", indent=1)
            except Exception as e:
                print_progress(f"Error plotting ROC curves: {str(e)}", indent=1)
                print_progress(f"Traceback: {traceback.format_exc()}", indent=2)

        for model_col in model_cols:
            display_name = model_name_mapping.get(model_col, model_col)
            try:
                print_progress(f"\nProcessing {display_name}...")

                primary_preds = []
                all_preds = []
                for pred in df[model_col]:
                    primary, all_labels = parse_prediction(pred)
                    primary_preds.append(primary)
                    all_preds.append(all_labels)

                print_progress("Calculating metrics...", indent=1)
                single_metrics = calculate_single_label_metrics(
                    true_primary, primary_preds, unique_labels
                )

                multi_metrics = calculate_multi_label_metrics(
                    true_multi, all_preds, unique_labels
                )

                # Store only after both succeed so the two dicts stay in sync.
                single_label_results[display_name] = single_metrics
                multi_label_results[display_name] = multi_metrics

                plot_confusion_matrix(
                    true_primary,
                    primary_preds,
                    unique_labels,
                    f'Confusion Matrix - {display_name}',
                    os.path.join(output_dir, f'confusion_matrix_{sanitize_filename(display_name)}.png')
                )

            except Exception as e:
                print_progress(f"Error processing {display_name}: {str(e)}", indent=1)
                print_progress(f"Traceback: {traceback.format_exc()}", indent=2)
                continue

        if single_label_results and multi_label_results:
            print_progress("\nCreating enhanced visualizations...")

            # Calculate temporal metrics
            temporal_df = calculate_temporal_metrics(single_label_results, multi_label_results)

            # Create enhanced visualizations (each plot produced once)
            plot_macro_metrics_progression(temporal_df, output_dir)
            plot_temporal_f1_heatmap(single_label_results, unique_labels, output_dir)
            plot_macro_metrics_radar(temporal_df, output_dir)
            plot_macro_metrics_faceted(temporal_df, output_dir)
            plot_macro_metrics_stacked(temporal_df, output_dir)

            print_progress("Creating comparative tables...")
            create_comparative_tables(single_label_results, multi_label_results, output_dir)

            # Save metrics to CSV
            print_progress("Saving metrics to CSV...")
            summary_data = []
            for model, metrics in multi_label_results.items():
                if 'overall' in metrics:
                    summary_data.append({
                        'Model': model,
                        'Hamming Loss': metrics['overall']['Hamming_Loss'],
                        'Subset Accuracy': metrics['overall']['Subset_Accuracy'],
                        'Partial Match': metrics['overall']['Partial_Match_Accuracy'],
                        'Avg Labels (True)': metrics['overall']['Label_Cardinality_True'],
                        'Avg Labels (Pred)': metrics['overall']['Label_Cardinality_Predicted']
                    })

            df_summary = pd.DataFrame(summary_data)
            df_summary.to_csv(os.path.join(output_dir, 'multilabel_summary.csv'), index=False)

            # Create styled HTML summary
            styled_table = df_summary.style\
                .format(precision=3)\
                .background_gradient(cmap='YlOrRd', subset=df_summary.columns[1:])\
                .to_html()

            with open(os.path.join(output_dir, 'multilabel_summary.html'), 'w', encoding='utf-8') as f:
                f.write(styled_table)

        end_time = time.time()
        execution_time = end_time - start_time
        print_progress(f"\nAnalysis complete! Total execution time: {execution_time:.2f} seconds")

    except Exception as e:
        print_progress(f"Fatal error in main execution: {str(e)}")
        print_progress(f"Traceback: {traceback.format_exc()}")
        raise


if __name__ == "__main__":
    main()